{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import os\n",
    "os.sys.path.insert(0, '/home/schirrmr/braindecode/code/braindecode/')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cropped Decoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial, we use cropped decoding: the ConvNet is trained on time windows (crops) within the trials instead of on complete trials. We explain this visually by comparing trialwise decoding to cropped decoding.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Trialwise Decoding | Cropped Decoding\n",
    "- | - \n",
    "![Trialwise Decoding](./trialwise_explanation.png \"Trialwise Decoding\") | ![Cropped Decoding](./cropped_explanation.png \"Cropped Decoding\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the left, you see trialwise decoding:\n",
    "\n",
    "1. A complete trial is pushed through the network\n",
    "2. The network produces a prediction\n",
    "3. The prediction is compared to the target (label) for that trial to compute the loss\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On the right, you see cropped decoding:\n",
    "\n",
    "1. Instead of a complete trial, windows within the trial, here called *crops*, are pushed through the network\n",
    "2. For computational efficiency, multiple neighbouring crops are pushed through the network simultaneously (these neighbouring crops are called a *supercrop*)\n",
    "3. Therefore, the network produces multiple predictions (one per crop in the supercrop)\n",
    "4. The individual crop predictions are averaged before computing the loss function\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notes:\n",
    "\n",
    "* The network architecture implicitly defines the crop size (it is the receptive field size, i.e., the number of timesteps the network uses to make a single prediction)\n",
    "* The supercrop size is a user-defined hyperparameter, called `input_time_length` in Braindecode. It mostly affects runtime (larger supercrop sizes should be faster). As a rule of thumb, you can set it to two times the crop size.\n",
    "* Crop size and supercrop size together define how many predictions the network makes per supercrop:  $\\mathrm{\\#supercrop}-\\mathrm{\\#crop}+1=\\mathrm{\\#predictions}$"
   ]
  },
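  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of the formula above, the following sketch computes the number of predictions per supercrop. The helper name `n_predictions` and the crop size of 200 samples are hypothetical and chosen only for illustration; the sketch also assumes that the network produces one prediction per timestep once it has been converted to dense predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def n_predictions(supercrop_size, crop_size):\n",
    "    # Number of dense predictions for one supercrop:\n",
    "    # #supercrop - #crop + 1 = #predictions\n",
    "    if crop_size > supercrop_size:\n",
    "        raise ValueError('The supercrop must be at least as long as the crop.')\n",
    "    return supercrop_size - crop_size + 1\n",
    "\n",
    "# Hypothetical crop size, not the receptive field of a real model\n",
    "crop_size = 200\n",
    "# Rule of thumb from the notes above: two times the crop size\n",
    "supercrop_size = 2 * crop_size\n",
    "print(n_predictions(supercrop_size, crop_size))"
   ]
  },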
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For cropped decoding, the above training setup is mathematically identical to sampling crops from your dataset, pushing them through the network, and training directly on the individual crops. At the same time, the above training setup is much faster because it avoids redundant computations by using dilated convolutions; see our paper [Deep learning with convolutional neural networks for EEG decoding and visualization](https://arxiv.org/abs/1703.05051). However, the two setups are only mathematically identical if (1) your network does not use any padding and (2) your loss function leads to the same gradients when using the averaged output. The first holds for our shallow and deep ConvNet models, and the second holds for the log-softmax outputs and negative log likelihood loss that are typically used for classification in PyTorch."
   ]
  },
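  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two conditions can be checked numerically on a toy example. The sketch below is not Braindecode code: `tiny_net` is a hypothetical two-layer network written in plain NumPy, with unpadded convolutions and no strided layers (so no dilations are needed), and all sizes are arbitrary. It first compares one dense pass over a supercrop with separate passes over each crop, and then compares the negative log likelihood of the averaged log-softmax output with the average of the per-crop losses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "def tiny_net(x, w1, w2):\n",
    "    # Two unpadded ('valid') 1-D convolutions with a ReLU in between\n",
    "    hidden = np.maximum(np.correlate(x, w1, mode='valid'), 0)\n",
    "    return np.correlate(hidden, w2, mode='valid')\n",
    "\n",
    "w1 = rng.randn(5)\n",
    "w2 = rng.randn(3)\n",
    "crop_size = len(w1) + len(w2) - 1  # receptive field of tiny_net: 7 samples\n",
    "supercrop = rng.randn(20)\n",
    "\n",
    "# Condition (1): dense pass over the whole supercrop ...\n",
    "dense_outs = tiny_net(supercrop, w1, w2)\n",
    "# ... versus pushing every crop through the network separately\n",
    "n_preds = len(supercrop) - crop_size + 1\n",
    "crop_outs = np.array([tiny_net(supercrop[i:i + crop_size], w1, w2)[0]\n",
    "                      for i in range(n_preds)])\n",
    "print(dense_outs.shape, np.allclose(dense_outs, crop_outs))\n",
    "\n",
    "# Condition (2): NLL of the averaged log-softmax output ...\n",
    "logits = rng.randn(n_preds, 2)\n",
    "log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))\n",
    "target = 1\n",
    "loss_of_mean = -log_probs.mean(axis=0)[target]\n",
    "# ... versus the average of the per-crop NLL losses\n",
    "mean_of_losses = np.mean(-log_probs[:, target])\n",
    "print(np.isclose(loss_of_mean, mean_of_losses))"
   ]
  },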
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Most of the code for cropped decoding is identical to the [Trialwise Decoding Tutorial](Trialwise_Decoding.html); the differences are explained in the text."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Enable logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "import importlib\n",
    "importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195\n",
    "log = logging.getLogger()\n",
    "log.setLevel('INFO')\n",
    "import sys\n",
    "logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',\n",
    "                     level=logging.INFO, stream=sys.stdout)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mne\n",
    "from mne.io import concatenate_raws\n",
    "\n",
    "# 5,6,9,10,13,14 are codes for executed and imagined hands/feet\n",
    "subject_id = 22\n",
    "event_codes = [5,6,9,10,13,14]\n",
    "#event_codes = [3,4,5,6,7,8,9,10,11,12,13,14]\n",
    "\n",
    "# This will download the files if you don't have them yet,\n",
    "# and then return the paths to the files.\n",
    "physionet_paths = mne.datasets.eegbci.load_data(subject_id, event_codes)\n",
    "\n",
    "# Load each of the files\n",
    "parts = [mne.io.read_raw_edf(path, preload=True, stim_channel='auto', verbose='WARNING')\n",
    "         for path in physionet_paths]\n",
    "\n",
    "# Concatenate them\n",
    "raw = concatenate_raws(parts)\n",
    "\n",
    "# Find the events in this dataset\n",
    "events, _ = mne.events_from_annotations(raw)\n",
    "\n",
    "# Use only EEG channels\n",
    "eeg_channel_inds = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,\n",
    "                   exclude='bads')\n",
    "\n",
    "# Extract trials, only using EEG channels\n",
    "epoched = mne.Epochs(raw, events, dict(hands_or_left=2, feet_or_right=3), tmin=1, tmax=4.1, proj=False, picks=eeg_channel_inds,\n",
    "                baseline=None, preload=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert data to Braindecode format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Convert data from volt to microvolt\n",
    "# Pytorch expects float32 for input and int64 for labels.\n",
    "X = (epoched.get_data() * 1e6).astype(np.float32)\n",
    "y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "\n",
    "train_set = SignalAndTarget(X[:40], y=y[:40])\n",
    "valid_set = SignalAndTarget(X[40:70], y=y[40:70])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "\n",
    "As in the trialwise decoding tutorial, we will use the Braindecode model class directly to perform the training in a few lines of code. If you instead want to use your own training loop, have a look at the [Cropped Manual Training Loop Tutorial](./Cropped_Manual_Training_Loop.html).\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For cropped decoding, we now transform the model into a model that outputs a dense time series of predictions.\n",
    "To do this, we manually set the length of the final convolution layer so that the receptive field of the ConvNet is smaller than the number of samples in a trial (see `final_conv_length=12` in the model definition)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "\n",
    "# Set if you want to use GPU\n",
    "# You can also use torch.cuda.is_available() to determine if cuda is available on your machine.\n",
    "cuda = False\n",
    "set_random_seeds(seed=20170629, cuda=cuda)\n",
    "n_classes = 2\n",
    "in_chans = train_set.X.shape[1]\n",
    "\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n",
    "                        input_time_length=None,\n",
    "                        final_conv_length=12)\n",
    "if cuda:\n",
    "    model.cuda()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we pass `cropped=True` to the model's `compile` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from braindecode.torch_ext.optimizers import AdamW\n",
    "import torch.nn.functional as F\n",
    "#optimizer = AdamW(model.parameters(), lr=1*0.01, weight_decay=0.5*0.001) # these are good values for the deep model\n",
    "optimizer = AdamW(model.parameters(), lr=0.0625 * 0.01, weight_decay=0)\n",
    "model.compile(loss=F.nll_loss, optimizer=optimizer,  iterator_seed=1, cropped=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For fitting, we must supply the supercrop size. Here, we set it to 450 via `input_time_length = 450`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_time_length = 450\n",
    "model.fit(train_set.X, train_set.y, epochs=30, batch_size=64, scheduler='cosine',\n",
    "          input_time_length=input_time_length,\n",
    "         validation_data=(valid_set.X, valid_set.y),)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.epochs_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eventually, we arrive at 90% accuracy, i.e., 27 of the 30 validation trials are predicted correctly."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_set = SignalAndTarget(X[70:], y=y[70:])\n",
    "\n",
    "model.evaluate(test_set.X, test_set.y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now predict entire trials as before or individual crops."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.predict_classes(test_set.X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.predict_outs(test_set.X)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For individual crops, provide `individual_crops=True`. This can be useful, for example, to plot accuracies over time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "import matplotlib\n",
    "%matplotlib inline\n",
    "matplotlib.style.use('seaborn')\n",
    "labels_per_trial_per_crop = model.predict_classes(test_set.X, individual_crops=True)\n",
    "accs_per_crop = [l == y for l,y in zip(labels_per_trial_per_crop, test_set.y)]\n",
    "accs_per_crop = np.mean(accs_per_crop, axis=0)\n",
    "plt.figure(figsize=(8,3))\n",
    "plt.plot(accs_per_crop * 100)\n",
    "plt.title(\"Accuracies per timestep\", fontsize=16)\n",
    "plt.xlabel('Timestep in trial', fontsize=14)\n",
    "plt.ylabel('Accuracy [%]', fontsize=14)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The raw outputs can be used to visualize the predicted class probabilities across a trial. Here we look at the first trial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cropped_outs = model.predict_outs(test_set.X, individual_crops=True)\n",
    "i_trial = 0\n",
    "# log-softmax outputs need to be exponentiated to get probabilities\n",
    "cropped_probs = np.exp(cropped_outs[i_trial])\n",
    "plt.figure(figsize=(8,3))\n",
    "plt.plot(cropped_probs.T)\n",
    "plt.title(\"Network probabilities for trial {:d} of class {:d}\".format(\n",
    "    i_trial, test_set.y[i_trial]), fontsize=16)\n",
    "plt.legend((\"Class 0\", \"Class 1\"), fontsize=12)\n",
    "plt.xlabel(\"Timestep within trial\", fontsize=14)\n",
    "plt.ylabel(\"Probabilities\", fontsize=14)"
   ]
  },
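  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To connect the per-crop outputs with the trialwise predictions, the following sketch averages per-crop log-softmax outputs over the crop axis and takes the most likely class per trial. It runs on synthetic data rather than on the trained model, and both the array layout (trials x classes x crops) and the choice to average log-probabilities are assumptions made for illustration, not a statement about Braindecode internals."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "n_trials, n_classes, n_crops = 4, 2, 10  # hypothetical sizes\n",
    "\n",
    "# Synthetic per-crop log-softmax outputs, shape (trials, classes, crops)\n",
    "logits = rng.randn(n_trials, n_classes, n_crops)\n",
    "log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))\n",
    "\n",
    "# Average over the crop axis, then pick the most likely class per trial\n",
    "trial_log_probs = log_probs.mean(axis=2)\n",
    "trial_labels = trial_log_probs.argmax(axis=1)\n",
    "\n",
    "# Per-crop labels, one label per trial and crop, for comparison\n",
    "crop_labels = log_probs.argmax(axis=1)\n",
    "print(trial_labels.shape, crop_labels.shape)"
   ]
  },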
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset references\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " This dataset was created and contributed to PhysioNet by the developers of the [BCI2000](http://www.schalklab.org/research/bci2000) instrumentation system, which they used in making these recordings. The system is described in:\n",
    " \n",
    "     Schalk, G., McFarland, D.J., Hinterberger, T., Birbaumer, N., Wolpaw, J.R. (2004) BCI2000: A General-Purpose Brain-Computer Interface (BCI) System. IEEE TBME 51(6):1034-1043.\n",
    "\n",
    "[PhysioBank](https://physionet.org/physiobank/) is a large and growing archive of well-characterized digital recordings of physiologic signals and related data for use by the biomedical research community and further described in:\n",
    "\n",
    "    Goldberger AL, Amaral LAN, Glass L, Hausdorff JM, Ivanov PCh, Mark RG, Mietus JE, Moody GB, Peng C-K, Stanley HE. (2000) PhysioBank, PhysioToolkit, and PhysioNet: Components of a New Research Resource for Complex Physiologic Signals. Circulation 101(23):e215-e220."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
